!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief routines that build the Kohn-Sham matrix contributions
!>      coming from local atomic densities
! **************************************************************************************************
MODULE qs_ks_atom

   USE ao_util,                         ONLY: trace_r_AxB
   USE atomic_kind_types,               ONLY: atomic_kind_type,&
                                              get_atomic_kind,&
                                              get_atomic_kind_set
   USE basis_set_types,                 ONLY: gto_basis_set_p_type,&
                                              gto_basis_set_type
   USE cp_array_utils,                  ONLY: cp_2d_r_p_type
   USE cp_control_types,                ONLY: dft_control_type
   USE cp_dbcsr_api,                    ONLY: dbcsr_get_block_p,&
                                              dbcsr_p_type
   USE kinds,                           ONLY: dp,&
                                              int_8
   USE kpoint_types,                    ONLY: get_kpoint_info,&
                                              kpoint_type
   USE message_passing,                 ONLY: mp_para_env_type
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_force_types,                  ONLY: qs_force_type
   USE qs_kind_types,                   ONLY: get_qs_kind,&
                                              get_qs_kind_set,&
                                              qs_kind_type
   USE qs_neighbor_list_types,          ONLY: get_iterator_task,&
                                              neighbor_list_iterate,&
                                              neighbor_list_iterator_create,&
                                              neighbor_list_iterator_p_type,&
                                              neighbor_list_iterator_release,&
                                              neighbor_list_set_p_type,&
                                              neighbor_list_task_type
   USE qs_nl_hash_table_types,          ONLY: nl_hash_table_add,&
                                              nl_hash_table_create,&
                                              nl_hash_table_get_from_index,&
                                              nl_hash_table_is_null,&
                                              nl_hash_table_obj,&
                                              nl_hash_table_release,&
                                              nl_hash_table_status
   USE qs_oce_methods,                  ONLY: prj_gather
   USE qs_oce_types,                    ONLY: oce_matrix_type
   USE qs_rho_atom_types,               ONLY: get_rho_atom,&
                                              rho_atom_coeff,&
                                              rho_atom_type
   USE sap_kind_types,                  ONLY: alist_post_align_blk,&
                                              alist_pre_align_blk,&
                                              alist_type,&
                                              get_alist
   USE util,                            ONLY: get_limit
   USE virial_methods,                  ONLY: virial_pair_force
   USE virial_types,                    ONLY: virial_type

!$ USE OMP_LIB, ONLY: omp_get_max_threads, &
!$                    omp_get_thread_num, &
!$                    omp_lock_kind, &
!$                    omp_init_lock, omp_set_lock, &
!$                    omp_unset_lock, omp_destroy_lock

#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'qs_ks_atom'

   PUBLIC :: update_ks_atom

CONTAINS

! **************************************************************************************************
!> \brief The correction to the KS matrix due to the GAPW local terms to the hartree and
!>      XC contributions is here added. The corresponding force contributions are also calculated
!>      if required. Each sparse-matrix block A-B is corrected by all the atomic contributions
!>      centered on atoms C for which the triplet A-C-B exists (they are close enough).
!>      To this end special lists are used.
!> \param qs_env qs environment, for the lists, the contraction coefficients and the
!>               pre calculated integrals of the potential with the atomic orbitals
!> \param ksmat KS matrix, sparse matrix; element nspins*(img-1)+ispin holds spin ispin of image img
!> \param pmat density matrix, sparse matrix, needed only for the forces (same indexing as ksmat)
!> \param forces switch for the calculation of the forces on atoms
!> \param tddft switch for TDDFT linear response (for nspins == 1 the KS contribution is doubled;
!>              only a single image is allowed)
!> \param rho_atom_external if present, used instead of the rho_atom_set of qs_env
!> \param kind_set_external if present, used instead of the qs_kind_set of qs_env
!> \param oce_external if present, used instead of the oce matrices of qs_env
!> \param sab_external if present, used instead of the sab_orb neighbor list of qs_env
!> \param kscale if present, multiplies both the KS-matrix factor and the force factor
!> \param kintegral if present, replaces the KS-matrix factor (takes precedence over tddft/kscale)
!> \param kforce if present, replaces the force factor (takes precedence over kscale)
!> \param fscale if present, forwarded to add_vhxca_forces instead of the default (1.0, 1.0)
!> \par History
!>      created [MI]
!>      the loop over the spins is done internally [03-05,MI]
!>      Rewrite using new OCE matrices [08.02.09, jhu]
!>      Add OpenMP [Apr 2016, EPCC]
!>      Allow for external kind_set, rho_atom_set, oce, sab 12.2019 (A. Bussy)
! **************************************************************************************************
   SUBROUTINE update_ks_atom(qs_env, ksmat, pmat, forces, tddft, rho_atom_external, &
                             kind_set_external, oce_external, sab_external, kscale, &
                             kintegral, kforce, fscale)

      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(dbcsr_p_type), DIMENSION(*), INTENT(INOUT)    :: ksmat, pmat
      LOGICAL, INTENT(IN)                                :: forces
      LOGICAL, INTENT(IN), OPTIONAL                      :: tddft
      TYPE(rho_atom_type), DIMENSION(:), OPTIONAL, &
         POINTER                                         :: rho_atom_external
      TYPE(qs_kind_type), DIMENSION(:), OPTIONAL, &
         POINTER                                         :: kind_set_external
      TYPE(oce_matrix_type), OPTIONAL, POINTER           :: oce_external
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         OPTIONAL, POINTER                               :: sab_external
      REAL(KIND=dp), INTENT(IN), OPTIONAL                :: kscale, kintegral, kforce
      REAL(KIND=dp), DIMENSION(2), INTENT(IN), OPTIONAL  :: fscale

      CHARACTER(len=*), PARAMETER                        :: routineN = 'update_ks_atom'

      INTEGER :: bo(2), handle, ia_kind, iac, iat, iatom, ibc, ikind, img, ip, ispin, ja_kind, &
         jatom, jkind, ka_kind, kac, katom, kbc, kkind, ldCPC, max_gau, max_nsgf, n_cont_a, &
         n_cont_b, nat, natom, nimages, nkind, nl_table_num_slots, nsoctot, nspins, slot_num
      INTEGER(KIND=int_8)                                :: nl_table_key
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: atom_of_kind
      INTEGER, DIMENSION(3)                              :: cell_b
      INTEGER, DIMENSION(:), POINTER                     :: atom_list, list_a, list_b
      INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
      LOGICAL                                            :: dista, distb, found, is_entry_null, &
                                                            is_task_valid, my_tddft, paw_atom, &
                                                            use_virial
      REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: a_matrix, p_matrix
      REAL(dp), DIMENSION(3)                             :: rac, rbc
      REAL(dp), DIMENSION(3, 3)                          :: force_tmp
      REAL(kind=dp)                                      :: eps_cpc, factor1, factor2
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: C_int_h, C_int_s, coc
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: dCPC_h, dCPC_s
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: PC_h, PC_s
      REAL(KIND=dp), DIMENSION(2)                        :: force_fac
      REAL(KIND=dp), DIMENSION(3, 3)                     :: pv_virial_thread
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: C_coeff_hh_a, C_coeff_hh_b, &
                                                            C_coeff_ss_a, C_coeff_ss_b
      TYPE(alist_type), POINTER                          :: alist_ac, alist_bc
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(cp_2d_r_p_type), DIMENSION(:), POINTER        :: mat_h, mat_p
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(gto_basis_set_p_type), DIMENSION(:), POINTER  :: basis_set_list
      TYPE(gto_basis_set_type), POINTER                  :: basis_set_a
      TYPE(kpoint_type), POINTER                         :: kpoints
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(neighbor_list_iterator_p_type), &
         DIMENSION(:), POINTER                           :: nl_iterator
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: my_sab
      TYPE(neighbor_list_task_type), POINTER             :: next_task, task
      TYPE(nl_hash_table_obj)                            :: nl_hash_table
      TYPE(oce_matrix_type), POINTER                     :: my_oce
      TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: my_kind_set
      TYPE(rho_atom_coeff), DIMENSION(:), POINTER        :: int_local_h, int_local_s
      TYPE(rho_atom_type), DIMENSION(:), POINTER         :: my_rho_atom
      TYPE(rho_atom_type), POINTER                       :: rho_at
      TYPE(virial_type), POINTER                         :: virial

!$    INTEGER(kind=omp_lock_kind), ALLOCATABLE, DIMENSION(:) :: locks
!$    INTEGER                                            :: lock_num

      CALL timeset(routineN, handle)

      NULLIFY (my_kind_set, atomic_kind_set, force, my_oce, para_env, my_rho_atom, my_sab)
      NULLIFY (mat_h, mat_p, dft_control)

      CALL get_qs_env(qs_env=qs_env, &
                      qs_kind_set=my_kind_set, &
                      atomic_kind_set=atomic_kind_set, &
                      force=force, &
                      oce=my_oce, &
                      para_env=para_env, &
                      rho_atom_set=my_rho_atom, &
                      virial=virial, &
                      sab_orb=my_sab, &
                      dft_control=dft_control)

      nspins = dft_control%nspins
      nimages = dft_control%nimages

      ! factor1 scales the KS-matrix contribution, factor2 the forces
      factor1 = 1.0_dp
      factor2 = 1.0_dp

      !deal with externals
      my_tddft = .FALSE.
      IF (PRESENT(tddft)) my_tddft = tddft
      IF (my_tddft) THEN
         IF (nspins == 1) factor1 = 2.0_dp
         CPASSERT(nimages == 1)
      END IF
      IF (PRESENT(kscale)) THEN
         factor1 = factor1*kscale
         factor2 = factor2*kscale
      END IF
      IF (PRESENT(kintegral)) factor1 = kintegral
      IF (PRESENT(kforce)) factor2 = kforce
      force_fac = 1.0_dp
      IF (PRESENT(fscale)) force_fac(:) = fscale(:)

      IF (PRESENT(rho_atom_external)) my_rho_atom => rho_atom_external
      IF (PRESENT(kind_set_external)) my_kind_set => kind_set_external
      IF (PRESENT(oce_external)) my_oce => oce_external
      IF (PRESENT(sab_external)) my_sab => sab_external

      ! kpoint images
      NULLIFY (cell_to_index)
      IF (nimages > 1) THEN
         CALL get_qs_env(qs_env=qs_env, kpoints=kpoints)
         CALL get_kpoint_info(kpoint=kpoints, cell_to_index=cell_to_index)
      END IF

      eps_cpc = dft_control%qs_control%gapw_control%eps_cpc

      CALL get_atomic_kind_set(atomic_kind_set, natom=natom)
      CALL get_qs_kind_set(my_kind_set, maxsgf=max_nsgf, maxgtops=max_gau, basis_type="GAPW_1C")

      IF (forces) THEN
         CALL get_atomic_kind_set(atomic_kind_set, atom_of_kind=atom_of_kind)
         ldCPC = max_gau
         use_virial = virial%pv_availability .AND. (.NOT. virial%pv_numer)
      ELSE
         use_virial = .FALSE.
      END IF

      pv_virial_thread(:, :) = 0.0_dp ! always initialize to avoid floating point exceptions in OMP REDUCTION

      nkind = SIZE(my_kind_set, 1)
      ! Collect the local potential contributions from all the processors
      ASSOCIATE (mepos => para_env%mepos, num_pe => para_env%num_pe)
      DO ikind = 1, nkind
         NULLIFY (atom_list)
         CALL get_atomic_kind(atomic_kind_set(ikind), atom_list=atom_list, natom=nat)
         CALL get_qs_kind(my_kind_set(ikind), paw_atom=paw_atom)
         IF (paw_atom) THEN
            ! gather the atomic block integrals in a more compressed format
            bo = get_limit(nat, num_pe, mepos)
            DO iat = bo(1), bo(2)
               iatom = atom_list(iat)
               DO ispin = 1, nspins
                  CALL prj_gather(my_rho_atom(iatom)%ga_Vlocal_gb_h(ispin)%r_coef, &
                                  my_rho_atom(iatom)%int_scr_h(ispin)%r_coef, my_kind_set(ikind))
                  CALL prj_gather(my_rho_atom(iatom)%ga_Vlocal_gb_s(ispin)%r_coef, &
                                  my_rho_atom(iatom)%int_scr_s(ispin)%r_coef, my_kind_set(ikind))
               END DO
            END DO
            ! broadcast the CPC arrays to all processors (replicated data)
            DO ip = 0, num_pe - 1
               bo = get_limit(nat, num_pe, ip)
               DO iat = bo(1), bo(2)
                  iatom = atom_list(iat)
                  DO ispin = 1, nspins
                     CALL para_env%bcast(my_rho_atom(iatom)%int_scr_h(ispin)%r_coef, ip)
                     CALL para_env%bcast(my_rho_atom(iatom)%int_scr_s(ispin)%r_coef, ip)
                  END DO
               END DO
            END DO
         END IF
      END DO
      END ASSOCIATE

      ALLOCATE (basis_set_list(nkind))
      DO ikind = 1, nkind
         CALL get_qs_kind(my_kind_set(ikind), basis_set=basis_set_a)
         IF (ASSOCIATED(basis_set_a)) THEN
            basis_set_list(ikind)%gto_basis_set => basis_set_a
         ELSE
            NULLIFY (basis_set_list(ikind)%gto_basis_set)
         END IF
      END DO

      ! build the hash table in serial...
      ! ... creation ...
      CALL neighbor_list_iterator_create(nl_iterator, my_sab)
      ! natom*natom/2 slots is probably not optimal, but it seems a good start.
      ! Evaluated in 64-bit so it cannot overflow the default integer for large natom;
      ! clamped to at least one slot (natom == 1) and to a value that fits in a default integer.
      nl_table_num_slots = INT(MIN(MAX(INT(natom, int_8)*INT(natom, int_8)/2_int_8, 1_int_8), &
                                   INT(HUGE(nl_table_num_slots)/2, int_8)))
      CALL nl_hash_table_create(nl_hash_table, nmax=nl_table_num_slots)
      ! ... and population
      DO WHILE (neighbor_list_iterate(nl_iterator) == 0) ! build hash table in serial, so don't pass mepos
         ALLOCATE (task) ! They must be deallocated before the nl_hash_table is released
         CALL get_iterator_task(nl_iterator, task) ! build hash table in serial, so don't pass mepos
         ! Pairs whose kinds have no basis set contribute nothing. Drop them here, so that the
         ! parallel task loop below never has to skip a task: a CYCLE there would bypass the
         ! advance to task%next and loop forever on the same task.
         IF (.NOT. (ASSOCIATED(basis_set_list(task%ikind)%gto_basis_set) .AND. &
                    ASSOCIATED(basis_set_list(task%jkind)%gto_basis_set))) THEN
            DEALLOCATE (task)
            CYCLE
         END IF
         ! tasks with the same key access the same blocks of H & P
         ! (key evaluated in int_8 to avoid default-integer overflow for large natom)
         IF (task%iatom <= task%jatom) THEN
            nl_table_key = INT(natom, int_8)*task%iatom + task%jatom
         ELSE
            nl_table_key = INT(natom, int_8)*task%jatom + task%iatom
         END IF
         CALL nl_hash_table_add(nl_hash_table, nl_table_key, task)
      END DO
      CALL neighbor_list_iterator_release(nl_iterator)

      ! Get the total number of (possibly empty) entries in the table
      CALL nl_hash_table_status(nl_hash_table, nmax=nl_table_num_slots)

!$OMP PARALLEL DEFAULT(NONE)                                         &
!$OMP           SHARED(nl_table_num_slots, nl_hash_table             &
!$OMP                 , max_gau, max_nsgf, nspins, forces            &
!$OMP                 , nimages, cell_to_index                       &
!$OMP                 , ksmat, pmat, natom, nkind, my_kind_set, my_oce &
!$OMP                 , my_rho_atom, factor1, factor2, use_virial    &
!$OMP                 , atom_of_kind, ldCPC, force, locks, force_fac &
!$OMP                 )                                              &
!$OMP          PRIVATE( slot_num, is_entry_null, TASK, is_task_valid &
!$OMP                 , C_int_h, C_int_s, coc, a_matrix, p_matrix    &
!$OMP                 , mat_h, mat_p, dCPC_h, dCPC_s, PC_h, PC_s     &
!$OMP                 , int_local_h, int_local_s                     &
!$OMP                 , ikind, jkind, iatom, jatom, cell_b           &
!$OMP                 , img                                          &
!$OMP                 , found, next_task                             &
!$OMP                 , kkind, paw_atom, lock_num                    &
!$OMP                 , iac, alist_ac, kac, n_cont_a, list_a         &
!$OMP                 , ibc, alist_bc, kbc, n_cont_b, list_b         &
!$OMP                 , katom, rho_at, nsoctot                       &
!$OMP                 , C_coeff_hh_a, C_coeff_ss_a, dista, rac       &
!$OMP                 , C_coeff_hh_b, C_coeff_ss_b, distb, rbc       &
!$OMP                 , ia_kind, ja_kind, ka_kind, force_tmp         &
!$OMP                 )                                              &
!$OMP        REDUCTION(+:pv_virial_thread                            &
!$OMP                 )

      ALLOCATE (C_int_h(max_gau*max_nsgf), C_int_s(max_gau*max_nsgf), coc(max_gau*max_gau), &
                a_matrix(max_gau, max_gau), p_matrix(max_nsgf, max_nsgf))

      ALLOCATE (mat_h(nspins), mat_p(nspins))
      DO ispin = 1, nspins
         NULLIFY (mat_h(ispin)%array, mat_p(ispin)%array)
      END DO

      IF (forces) THEN
         ALLOCATE (dCPC_h(max_gau, max_gau), dCPC_s(max_gau, max_gau), &
                   PC_h(max_nsgf, max_gau, nspins), PC_s(max_nsgf, max_gau, nspins))
!$OMP SINGLE
!$       ALLOCATE (locks(natom*nkind))
!$OMP END SINGLE

!$OMP DO
!$       do lock_num = 1, natom*nkind
!$          call omp_init_lock(locks(lock_num))
!$       end do
!$OMP END DO
      END IF

      ! Dynamic schedule to take account of the fact that some slots may be empty
      ! or contain 1 or more tasks
!$OMP DO SCHEDULE(DYNAMIC,5)
      DO slot_num = 1, nl_table_num_slots
         CALL nl_hash_table_is_null(nl_hash_table, slot_num, is_entry_null)

         IF (.NOT. is_entry_null) THEN
            CALL nl_hash_table_get_from_index(nl_hash_table, slot_num, task)

            ! Walk the chain of tasks stored in this slot. Every iteration must reach the
            ! advance to task%next at the bottom, so nothing here may CYCLE this loop.
            is_task_valid = ASSOCIATED(task)
            DO WHILE (is_task_valid)

               ikind = task%ikind
               jkind = task%jkind
               iatom = task%iatom
               jatom = task%jatom
               cell_b(:) = task%cell(:)
               ! tasks whose kinds lack a basis set were already dropped while building the table

               IF (nimages > 1) THEN
                  img = cell_to_index(cell_b(1), cell_b(2), cell_b(3))
                  CPASSERT(img > 0)
               ELSE
                  img = 1
               END IF

               ! only the upper triangle (row <= col) of the symmetric matrices is stored
               DO ispin = 1, nspins
                  NULLIFY (mat_h(ispin)%array, mat_p(ispin)%array)

                  found = .FALSE.
                  IF (iatom <= jatom) THEN
                     CALL dbcsr_get_block_p(matrix=ksmat(nspins*(img - 1) + ispin)%matrix, &
                                            row=iatom, col=jatom, &
                                            BLOCK=mat_h(ispin)%array, found=found)
                  ELSE
                     CALL dbcsr_get_block_p(matrix=ksmat(nspins*(img - 1) + ispin)%matrix, &
                                            row=jatom, col=iatom, &
                                            BLOCK=mat_h(ispin)%array, found=found)
                  END IF
                  CPASSERT(found)

                  IF (forces) THEN
                     found = .FALSE.
                     IF (iatom <= jatom) THEN
                        CALL dbcsr_get_block_p(matrix=pmat(nspins*(img - 1) + ispin)%matrix, &
                                               row=iatom, col=jatom, &
                                               BLOCK=mat_p(ispin)%array, found=found)
                     ELSE
                        CALL dbcsr_get_block_p(matrix=pmat(nspins*(img - 1) + ispin)%matrix, &
                                               row=jatom, col=iatom, &
                                               BLOCK=mat_p(ispin)%array, found=found)
                     END IF
                     CPASSERT(found)
                  END IF
               END DO

               DO kkind = 1, nkind
                  CALL get_qs_kind(my_kind_set(kkind), paw_atom=paw_atom)

                  IF (.NOT. paw_atom) CYCLE

                  iac = ikind + nkind*(kkind - 1)
                  ibc = jkind + nkind*(kkind - 1)
                  IF (.NOT. ASSOCIATED(my_oce%intac(iac)%alist)) CYCLE
                  IF (.NOT. ASSOCIATED(my_oce%intac(ibc)%alist)) CYCLE

                  CALL get_alist(my_oce%intac(iac), alist_ac, iatom)
                  CALL get_alist(my_oce%intac(ibc), alist_bc, jatom)
                  IF (.NOT. ASSOCIATED(alist_ac)) CYCLE
                  IF (.NOT. ASSOCIATED(alist_bc)) CYCLE

                  DO kac = 1, alist_ac%nclist
                     DO kbc = 1, alist_bc%nclist
                        IF (alist_ac%clist(kac)%catom /= alist_bc%clist(kbc)%catom) CYCLE

                        IF (ALL(cell_b + alist_bc%clist(kbc)%cell - alist_ac%clist(kac)%cell == 0)) THEN
                           n_cont_a = alist_ac%clist(kac)%nsgf_cnt
                           n_cont_b = alist_bc%clist(kbc)%nsgf_cnt
                           IF (n_cont_a == 0 .OR. n_cont_b == 0) CYCLE

                           list_a => alist_ac%clist(kac)%sgf_list
                           list_b => alist_bc%clist(kbc)%sgf_list

                           katom = alist_ac%clist(kac)%catom

                           ! one-center case (A == C in the home cell): use the hard integrals
                           IF (iatom == katom .AND. ALL(alist_ac%clist(kac)%cell == 0)) THEN
                              C_coeff_hh_a => alist_ac%clist(kac)%achint
                              C_coeff_ss_a => alist_ac%clist(kac)%acint
                              dista = .FALSE.
                           ELSE
                              C_coeff_hh_a => alist_ac%clist(kac)%acint
                              C_coeff_ss_a => alist_ac%clist(kac)%acint
                              dista = .TRUE.
                           END IF

                           IF (jatom == katom .AND. ALL(alist_bc%clist(kbc)%cell == 0)) THEN
                              C_coeff_hh_b => alist_bc%clist(kbc)%achint
                              C_coeff_ss_b => alist_bc%clist(kbc)%acint
                              distb = .FALSE.
                           ELSE
                              C_coeff_hh_b => alist_bc%clist(kbc)%acint
                              C_coeff_ss_b => alist_bc%clist(kbc)%acint
                              distb = .TRUE.
                           END IF

                           rho_at => my_rho_atom(katom)
                           nsoctot = SIZE(C_coeff_ss_a, 2)

                           CALL get_rho_atom(rho_atom=rho_at, int_scr_h=int_local_h, int_scr_s=int_local_s)
                           CALL add_vhxca_to_ks(mat_h, C_coeff_hh_a, C_coeff_hh_b, C_coeff_ss_a, C_coeff_ss_b, &
                                                int_local_h, int_local_s, nspins, iatom, jatom, nsoctot, factor1, &
                                                list_a, n_cont_a, list_b, n_cont_b, C_int_h, C_int_s, a_matrix, dista, distb, coc)

                           IF (forces) THEN
                              IF (use_virial) THEN
                                 rac = alist_ac%clist(kac)%rac
                                 rbc = alist_bc%clist(kbc)%rac
                              END IF
                              ia_kind = atom_of_kind(iatom)
                              ja_kind = atom_of_kind(jatom)
                              ka_kind = atom_of_kind(katom)
                              rho_at => my_rho_atom(katom)
                              force_tmp(1:3, 1:3) = 0.0_dp

                              CALL get_rho_atom(rho_atom=rho_at, int_scr_h=int_local_h, int_scr_s=int_local_s)
                              ! force accumulation into shared arrays is protected by per-(atom,kind) locks
                              IF (iatom <= jatom) THEN
                                 CALL add_vhxca_forces(mat_p, C_coeff_hh_a, C_coeff_hh_b, C_coeff_ss_a, C_coeff_ss_b, &
                                                       int_local_h, int_local_s, force_tmp, nspins, iatom, jatom, nsoctot, &
                                                       list_a, n_cont_a, list_b, n_cont_b, dCPC_h, dCPC_s, ldCPC, &
                                                       PC_h, PC_s, p_matrix, force_fac)
                                 force_tmp = factor2*force_tmp
!$                               CALL omp_set_lock(locks((ka_kind - 1)*nkind + kkind))
                                 force(kkind)%vhxc_atom(1:3, ka_kind) = force(kkind)%vhxc_atom(1:3, ka_kind) + force_tmp(1:3, 3)
!$                               CALL omp_unset_lock(locks((ka_kind - 1)*nkind + kkind))
!$                               CALL omp_set_lock(locks((ia_kind - 1)*nkind + ikind))
                                 force(ikind)%vhxc_atom(1:3, ia_kind) = force(ikind)%vhxc_atom(1:3, ia_kind) + force_tmp(1:3, 1)
!$                               CALL omp_unset_lock(locks((ia_kind - 1)*nkind + ikind))
!$                               CALL omp_set_lock(locks((ja_kind - 1)*nkind + jkind))
                                 force(jkind)%vhxc_atom(1:3, ja_kind) = force(jkind)%vhxc_atom(1:3, ja_kind) + force_tmp(1:3, 2)
!$                               CALL omp_unset_lock(locks((ja_kind - 1)*nkind + jkind))
                                 IF (use_virial) THEN
                                    CALL virial_pair_force(pv_virial_thread, 1._dp, force_tmp(1:3, 1), rac)
                                    CALL virial_pair_force(pv_virial_thread, 1._dp, force_tmp(1:3, 2), rbc)
                                 END IF
                              ELSE
                                 ! the block is stored as (jatom, iatom): swap the roles of A and B
                                 CALL add_vhxca_forces(mat_p, C_coeff_hh_b, C_coeff_hh_a, C_coeff_ss_b, C_coeff_ss_a, &
                                                       int_local_h, int_local_s, force_tmp, nspins, jatom, iatom, nsoctot, &
                                                       list_b, n_cont_b, list_a, n_cont_a, dCPC_h, dCPC_s, ldCPC, &
                                                       PC_h, PC_s, p_matrix, force_fac)
                                 force_tmp = factor2*force_tmp
!$                               CALL omp_set_lock(locks((ka_kind - 1)*nkind + kkind))
                                 force(kkind)%vhxc_atom(1:3, ka_kind) = force(kkind)%vhxc_atom(1:3, ka_kind) + force_tmp(1:3, 3)
!$                               CALL omp_unset_lock(locks((ka_kind - 1)*nkind + kkind))
!$                               CALL omp_set_lock(locks((ia_kind - 1)*nkind + ikind))
                                 force(ikind)%vhxc_atom(1:3, ia_kind) = force(ikind)%vhxc_atom(1:3, ia_kind) + force_tmp(1:3, 2)
!$                               CALL omp_unset_lock(locks((ia_kind - 1)*nkind + ikind))
!$                               CALL omp_set_lock(locks((ja_kind - 1)*nkind + jkind))
                                 force(jkind)%vhxc_atom(1:3, ja_kind) = force(jkind)%vhxc_atom(1:3, ja_kind) + force_tmp(1:3, 1)
!$                               CALL omp_unset_lock(locks((ja_kind - 1)*nkind + jkind))
                                 IF (use_virial) THEN
                                    CALL virial_pair_force(pv_virial_thread, 1._dp, force_tmp(1:3, 2), rac)
                                    CALL virial_pair_force(pv_virial_thread, 1._dp, force_tmp(1:3, 1), rbc)
                                 END IF
                              END IF

                           END IF
                           EXIT ! search loop over jatom-katom list
                        END IF
                     END DO ! kbc
                  END DO ! kac

               END DO ! kkind

               next_task => task%next
               ! We are done with this task, we can deallocate it
               DEALLOCATE (task)
               is_task_valid = ASSOCIATED(next_task)
               IF (is_task_valid) task => next_task

            END DO

         END IF
      END DO
!$OMP END DO

      DO ispin = 1, nspins
         NULLIFY (mat_h(ispin)%array, mat_p(ispin)%array)
      END DO
      DEALLOCATE (mat_h, mat_p, C_int_h, C_int_s, a_matrix, p_matrix, coc)

      IF (forces) THEN
         DEALLOCATE (dCPC_h, dCPC_s, PC_h, PC_s)

         ! Implicit barrier at end of OMP DO, so locks can be freed
!$OMP DO
!$       DO lock_num = 1, natom*nkind
!$          call omp_destroy_lock(locks(lock_num))
!$       END DO
!$OMP END DO

!$OMP SINGLE
!$       DEALLOCATE (locks)
!$OMP END SINGLE NOWAIT
      END IF

!$OMP END PARALLEL

      IF (use_virial) THEN
         virial%pv_gapw(1:3, 1:3) = virial%pv_gapw(1:3, 1:3) + pv_virial_thread(1:3, 1:3)
         virial%pv_virial(1:3, 1:3) = virial%pv_virial(1:3, 1:3) + pv_virial_thread(1:3, 1:3)
      END IF

      CALL nl_hash_table_release(nl_hash_table)

      DEALLOCATE (basis_set_list)

      CALL timestop(handle)

   END SUBROUTINE update_ks_atom

! **************************************************************************************************
!> \brief ...
!> \param mat_h ...
!> \param C_hh_a ...
!> \param C_hh_b ...
!> \param C_ss_a ...
!> \param C_ss_b ...
!> \param int_local_h ...
!> \param int_local_s ...
!> \param nspins ...
!> \param ia ...
!> \param ja ...
!> \param nsp ...
!> \param factor ...
!> \param lista ...
!> \param nconta ...
!> \param listb ...
!> \param ncontb ...
!> \param C_int_h ...
!> \param C_int_s ...
!> \param a_matrix ...
!> \param dista ...
!> \param distb ...
!> \param coc ...
! **************************************************************************************************
   SUBROUTINE add_vhxca_to_ks(mat_h, C_hh_a, C_hh_b, C_ss_a, C_ss_b, &
                              int_local_h, int_local_s, &
                              nspins, ia, ja, nsp, factor, lista, nconta, listb, ncontb, &
                              C_int_h, C_int_s, a_matrix, dista, distb, coc)
      TYPE(cp_2d_r_p_type), DIMENSION(:), POINTER        :: mat_h
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: C_hh_a, C_hh_b, C_ss_a, C_ss_b
      TYPE(rho_atom_coeff), DIMENSION(:), POINTER        :: int_local_h, int_local_s
      INTEGER, INTENT(IN)                                :: nspins, ia, ja, nsp
      REAL(KIND=dp), INTENT(IN)                          :: factor
      INTEGER, DIMENSION(:), INTENT(IN)                  :: lista
      INTEGER, INTENT(IN)                                :: nconta
      INTEGER, DIMENSION(:), INTENT(IN)                  :: listb
      INTEGER, INTENT(IN)                                :: ncontb
      REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: C_int_h, C_int_s
      REAL(dp), DIMENSION(:, :)                          :: a_matrix
      LOGICAL, INTENT(IN)                                :: dista, distb
      REAL(dp), DIMENSION(:)                             :: coc

      INTEGER                                            :: regime, spin_idx
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: ks_blk, v_hard, v_soft

      ! For every spin channel this builds the dense pair block
      !    ia <= ja :  a_matrix = factor*(Ch_a Vh Ch_b^T - Cs_a Vs Cs_b^T)   (nconta x ncontb)
      !    ia >  ja :  a_matrix = factor*(Ch_b Vh Ch_a^T - Cs_b Vs Cs_a^T)   (ncontb x nconta)
      ! where Ch/Cs are slice 1 of C_hh/C_ss and Vh/Vs are int_local_h/s(spin)%r_coef.
      ! The block is then handed to alist_post_align_blk with the index lists; how it is combined
      ! with mat_h(spin)%array is decided by that routine.
      ! An atom flagged as distant (dista/distb) uses its hard coefficients in place of the soft
      ! ones, which lets the products be regrouped to save DGEMM calls.
      !
      ! regime: 0 = no atom distant, 1 = only a, 2 = only b, 3 = both
      regime = MERGE(1, 0, dista) + MERGE(2, 0, distb)

      NULLIFY (ks_blk, v_hard, v_soft)

      DO spin_idx = 1, nspins
         ks_blk => mat_h(spin_idx)%array
         v_hard => int_local_h(spin_idx)%r_coef
         v_soft => int_local_s(spin_idx)%r_coef

         ! Both atoms distant: pack (Vh - Vs) column-major into coc; shared by both orientations.
         IF (regime == 3) THEN
            coc(1:nsp*nsp) = RESHAPE(v_hard(1:nsp, 1:nsp) - v_soft(1:nsp, 1:nsp), [nsp*nsp])
         END IF

         IF (ia <= ja) THEN
            ! Block laid out as (functions of a) x (functions of b).
            SELECT CASE (regime)
            CASE (3)
               CALL DGEMM('N', 'T', nsp, ncontb, nsp, 1.0_dp, coc, nsp, &
                          C_hh_b(:, :, 1), SIZE(C_hh_b, 1), 0.0_dp, C_int_h, nsp)
               CALL DGEMM('N', 'N', nconta, ncontb, nsp, factor, C_hh_a(:, :, 1), SIZE(C_hh_a, 1), &
                          C_int_h, nsp, 0.0_dp, a_matrix, SIZE(a_matrix, 1))
            CASE (1)
               ! coc = Vh Ch_b^T - Vs Cs_b^T   (nsp x ncontb)
               CALL DGEMM('N', 'T', nsp, ncontb, nsp, 1.0_dp, v_hard, SIZE(v_hard, 1), &
                          C_hh_b(:, :, 1), SIZE(C_hh_b, 1), 0.0_dp, coc, nsp)
               CALL DGEMM('N', 'T', nsp, ncontb, nsp, -1.0_dp, v_soft, SIZE(v_soft, 1), &
                          C_ss_b(:, :, 1), SIZE(C_ss_b, 1), 1.0_dp, coc, nsp)
               CALL DGEMM('N', 'N', nconta, ncontb, nsp, factor, C_hh_a(:, :, 1), SIZE(C_hh_a, 1), &
                          coc, nsp, 0.0_dp, a_matrix, SIZE(a_matrix, 1))
            CASE (2)
               ! coc = factor*(Ch_a Vh - Cs_a Vs)   (nconta x nsp)
               CALL DGEMM('N', 'N', nconta, nsp, nsp, factor, C_hh_a(:, :, 1), SIZE(C_hh_a, 1), &
                          v_hard, SIZE(v_hard, 1), 0.0_dp, coc, nconta)
               CALL DGEMM('N', 'N', nconta, nsp, nsp, -factor, C_ss_a(:, :, 1), SIZE(C_ss_a, 1), &
                          v_soft, SIZE(v_soft, 1), 1.0_dp, coc, nconta)
               CALL DGEMM('N', 'T', nconta, ncontb, nsp, 1.0_dp, coc, nconta, &
                          C_hh_b(:, :, 1), SIZE(C_hh_b, 1), 0.0_dp, a_matrix, SIZE(a_matrix, 1))
            CASE DEFAULT
               CALL DGEMM('N', 'T', nsp, ncontb, nsp, 1.0_dp, v_hard, SIZE(v_hard, 1), &
                          C_hh_b(:, :, 1), SIZE(C_hh_b, 1), 0.0_dp, C_int_h, nsp)
               CALL DGEMM('N', 'T', nsp, ncontb, nsp, 1.0_dp, v_soft, SIZE(v_soft, 1), &
                          C_ss_b(:, :, 1), SIZE(C_ss_b, 1), 0.0_dp, C_int_s, nsp)
               CALL DGEMM('N', 'N', nconta, ncontb, nsp, factor, C_hh_a(:, :, 1), SIZE(C_hh_a, 1), &
                          C_int_h, nsp, 0.0_dp, a_matrix, SIZE(a_matrix, 1))
               CALL DGEMM('N', 'N', nconta, ncontb, nsp, -factor, C_ss_a(:, :, 1), SIZE(C_ss_a, 1), &
                          C_int_s, nsp, 1.0_dp, a_matrix, SIZE(a_matrix, 1))
            END SELECT

            CALL alist_post_align_blk(a_matrix, SIZE(a_matrix, 1), ks_blk, SIZE(ks_blk, 1), &
                                      lista, nconta, listb, ncontb)
         ELSE
            ! Block laid out as (functions of b) x (functions of a).
            SELECT CASE (regime)
            CASE (3)
               CALL DGEMM('N', 'T', nsp, nconta, nsp, 1.0_dp, coc, nsp, &
                          C_hh_a(:, :, 1), SIZE(C_hh_a, 1), 0.0_dp, C_int_h, nsp)
               CALL DGEMM('N', 'N', ncontb, nconta, nsp, factor, C_hh_b(:, :, 1), SIZE(C_hh_b, 1), &
                          C_int_h, nsp, 0.0_dp, a_matrix, SIZE(a_matrix, 1))
            CASE (1)
               ! coc = factor*(Ch_b Vh - Cs_b Vs)   (ncontb x nsp)
               CALL DGEMM('N', 'N', ncontb, nsp, nsp, factor, C_hh_b(:, :, 1), SIZE(C_hh_b, 1), &
                          v_hard, SIZE(v_hard, 1), 0.0_dp, coc, ncontb)
               CALL DGEMM('N', 'N', ncontb, nsp, nsp, -factor, C_ss_b(:, :, 1), SIZE(C_ss_b, 1), &
                          v_soft, SIZE(v_soft, 1), 1.0_dp, coc, ncontb)
               CALL DGEMM('N', 'T', ncontb, nconta, nsp, 1.0_dp, coc, ncontb, &
                          C_hh_a(:, :, 1), SIZE(C_hh_a, 1), 0.0_dp, a_matrix, SIZE(a_matrix, 1))
            CASE (2)
               ! coc = Vh Ch_a^T - Vs Cs_a^T   (nsp x nconta)
               CALL DGEMM('N', 'T', nsp, nconta, nsp, 1.0_dp, v_hard, SIZE(v_hard, 1), &
                          C_hh_a(:, :, 1), SIZE(C_hh_a, 1), 0.0_dp, coc, nsp)
               CALL DGEMM('N', 'T', nsp, nconta, nsp, -1.0_dp, v_soft, SIZE(v_soft, 1), &
                          C_ss_a(:, :, 1), SIZE(C_ss_a, 1), 1.0_dp, coc, nsp)
               CALL DGEMM('N', 'N', ncontb, nconta, nsp, factor, C_hh_b(:, :, 1), SIZE(C_hh_b, 1), &
                          coc, nsp, 0.0_dp, a_matrix, SIZE(a_matrix, 1))
            CASE DEFAULT
               CALL DGEMM('N', 'T', nsp, nconta, nsp, 1.0_dp, v_hard, SIZE(v_hard, 1), &
                          C_hh_a(:, :, 1), SIZE(C_hh_a, 1), 0.0_dp, C_int_h, nsp)
               CALL DGEMM('N', 'T', nsp, nconta, nsp, 1.0_dp, v_soft, SIZE(v_soft, 1), &
                          C_ss_a(:, :, 1), SIZE(C_ss_a, 1), 0.0_dp, C_int_s, nsp)
               CALL DGEMM('N', 'N', ncontb, nconta, nsp, factor, C_hh_b(:, :, 1), SIZE(C_hh_b, 1), &
                          C_int_h, nsp, 0.0_dp, a_matrix, SIZE(a_matrix, 1))
               CALL DGEMM('N', 'N', ncontb, nconta, nsp, -factor, C_ss_b(:, :, 1), SIZE(C_ss_b, 1), &
                          C_int_s, nsp, 1.0_dp, a_matrix, SIZE(a_matrix, 1))
            END SELECT

            CALL alist_post_align_blk(a_matrix, SIZE(a_matrix, 1), ks_blk, SIZE(ks_blk, 1), &
                                      listb, ncontb, lista, nconta)
         END IF
      END DO

   END SUBROUTINE add_vhxca_to_ks

! **************************************************************************************************
!> \brief ...
!> \param mat_p ...
!> \param C_hh_a ...
!> \param C_hh_b ...
!> \param C_ss_a ...
!> \param C_ss_b ...
!> \param int_local_h ...
!> \param int_local_s ...
!> \param force ...
!> \param nspins ...
!> \param ia ...
!> \param ja ...
!> \param nsp ...
!> \param lista ...
!> \param nconta ...
!> \param listb ...
!> \param ncontb ...
!> \param dCPC_h ...
!> \param dCPC_s ...
!> \param ldCPC ...
!> \param PC_h ...
!> \param PC_s ...
!> \param p_matrix ...
!> \param force_scaling ...
! **************************************************************************************************
   SUBROUTINE add_vhxca_forces(mat_p, C_hh_a, C_hh_b, C_ss_a, C_ss_b, &
                               int_local_h, int_local_s, &
                               force, nspins, ia, ja, nsp, lista, nconta, listb, ncontb, &
                               dCPC_h, dCPC_s, ldCPC, PC_h, PC_s, p_matrix, force_scaling)
      TYPE(cp_2d_r_p_type), DIMENSION(:), INTENT(IN), &
         POINTER                                         :: mat_p
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: C_hh_a, C_hh_b, C_ss_a, C_ss_b
      TYPE(rho_atom_coeff), DIMENSION(:), POINTER        :: int_local_h, int_local_s
      REAL(dp), DIMENSION(3, 3), INTENT(INOUT)           :: force
      INTEGER, INTENT(IN)                                :: nspins, ia, ja, nsp
      INTEGER, DIMENSION(:), INTENT(IN)                  :: lista
      INTEGER, INTENT(IN)                                :: nconta
      INTEGER, DIMENSION(:), INTENT(IN)                  :: listb
      INTEGER, INTENT(IN)                                :: ncontb
      REAL(KIND=dp), DIMENSION(:, :)                     :: dCPC_h, dCPC_s
      INTEGER, INTENT(IN)                                :: ldCPC
      REAL(KIND=dp), DIMENSION(:, :, :)                  :: PC_h, PC_s
      REAL(KIND=dp), DIMENSION(:, :)                     :: p_matrix
      REAL(KIND=dp), DIMENSION(2), INTENT(IN)            :: force_scaling

      INTEGER                                            :: dir, ispin
      REAL(dp), DIMENSION(:, :), POINTER                 :: int_hard, int_soft
      REAL(KIND=dp)                                      :: ieqj, trace
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: p_block

!   if(dista.and.distb) we could also here use ChPCh = CsPCs
!   *** factor 2 because only half of the pairs with ia =/ ja are considered

      ieqj = 2.0_dp
      IF (ia == ja) ieqj = 1.0_dp

      NULLIFY (int_hard, int_soft)

      DO ispin = 1, nspins
         p_block => mat_p(ispin)%array

         CALL alist_pre_align_blk(p_block, SIZE(p_block, 1), p_matrix, SIZE(p_matrix, 1), &
                                  lista, nconta, listb, ncontb)

         int_hard => int_local_h(ispin)%r_coef
         int_soft => int_local_s(ispin)%r_coef

         CALL DGEMM('N', 'N', nconta, nsp, ncontb, 1.0_dp, p_matrix, SIZE(p_matrix, 1), &
                    C_hh_b(:, :, 1), SIZE(C_hh_b, 1), &
                    0.0_dp, PC_h(:, :, ispin), SIZE(PC_h, 1))
         CALL DGEMM('N', 'N', nconta, nsp, ncontb, 1.0_dp, p_matrix(:, :), SIZE(p_matrix, 1), &
                    C_ss_b(:, :, 1), SIZE(C_ss_b, 1), &
                    0.0_dp, PC_s(:, :, ispin), SIZE(PC_s, 1))

         DO dir = 2, 4
            CALL DGEMM('T', 'N', nsp, nsp, nconta, 1.0_dp, PC_h(:, :, ispin), SIZE(PC_h, 1), &
                       C_hh_a(:, :, dir), SIZE(C_hh_a, 1), &
                       0.0_dp, dCPC_h, SIZE(dCPC_h, 1))
            trace = trace_r_AxB(dCPC_h, ldCPC, int_hard, nsp, nsp, nsp)
            force(dir - 1, 3) = force(dir - 1, 3) + ieqj*trace*force_scaling(ispin)
            force(dir - 1, 1) = force(dir - 1, 1) - ieqj*trace*force_scaling(ispin)

            CALL DGEMM('T', 'N', nsp, nsp, nconta, 1.0_dp, PC_s(:, :, ispin), SIZE(PC_s, 1), &
                       C_ss_a(:, :, dir), SIZE(C_ss_a, 1), &
                       0.0_dp, dCPC_s, SIZE(dCPC_s, 1))
            trace = trace_r_AxB(dCPC_s, ldCPC, int_soft, nsp, nsp, nsp)
            force(dir - 1, 3) = force(dir - 1, 3) - ieqj*trace*force_scaling(ispin)
            force(dir - 1, 1) = force(dir - 1, 1) + ieqj*trace*force_scaling(ispin)
         END DO

         ! j-k contributions
         CALL DGEMM('T', 'N', ncontb, nsp, nconta, 1.0_dp, p_matrix, SIZE(p_matrix, 1), &
                    C_hh_a(:, :, 1), SIZE(C_hh_a, 1), &
                    0.0_dp, PC_h(:, :, ispin), SIZE(PC_h, 1))
         CALL DGEMM('T', 'N', ncontb, nsp, nconta, 1.0_dp, p_matrix, SIZE(p_matrix, 1), &
                    C_ss_a(:, :, 1), SIZE(C_ss_a, 1), &
                    0.0_dp, PC_s(:, :, ispin), SIZE(PC_s, 1))

         DO dir = 2, 4
            CALL DGEMM('T', 'N', nsp, nsp, ncontb, 1.0_dp, PC_h(:, :, ispin), SIZE(PC_h, 1), &
                       C_hh_b(:, :, dir), SIZE(C_hh_b, 1), &
                       0.0_dp, dCPC_h, SIZE(dCPC_h, 1))
            trace = trace_r_AxB(dCPC_h, ldCPC, int_hard, nsp, nsp, nsp)
            force(dir - 1, 3) = force(dir - 1, 3) + ieqj*trace*force_scaling(ispin)
            force(dir - 1, 2) = force(dir - 1, 2) - ieqj*trace*force_scaling(ispin)

            CALL DGEMM('T', 'N', nsp, nsp, ncontb, 1.0_dp, PC_s(:, :, ispin), SIZE(PC_s, 1), &
                       C_ss_b(:, :, dir), SIZE(C_ss_b, 1), &
                       0.0_dp, dCPC_s, SIZE(dCPC_s, 1))
            trace = trace_r_AxB(dCPC_s, ldCPC, int_soft, nsp, nsp, nsp)
            force(dir - 1, 3) = force(dir - 1, 3) - ieqj*trace*force_scaling(ispin)
            force(dir - 1, 2) = force(dir - 1, 2) + ieqj*trace*force_scaling(ispin)
         END DO

      END DO !ispin

   END SUBROUTINE add_vhxca_forces

END MODULE qs_ks_atom
